import numpy as np

from .deep.feature_extractor import Extractor
from .sort.nn_matching import NearestNeighborDistanceMetric
from .sort.detection import Detection
from .sort.tracker import Tracker

__all__ = ['DeepSort']


class DeepSort(object):
    """Multi-object tracker combining an appearance extractor with a Tracker.

    Each call to :meth:`update` crops the detections out of the frame, runs
    the crops through ``self.extractor`` to get appearance features, wraps the
    confident detections in ``Detection`` objects, and advances
    ``self.tracker`` by one step.
    """

    def __init__(self, model_path, max_dist=0.2, min_confidence=0.3,
                 nms_max_overlap=1.0, nn_budget=100, use_cuda=True):
        """Build the feature extractor and the tracker.

        Args:
            model_path: Passed to ``Extractor`` as the model location.
            max_dist: Threshold handed to the "cosine"
                ``NearestNeighborDistanceMetric``.
            min_confidence: Detections with a confidence not strictly greater
                than this value are ignored by :meth:`update`.
            nms_max_overlap: Stored on the instance; not used by any method
                of this class.
            nn_budget: Budget argument handed to
                ``NearestNeighborDistanceMetric``.
            use_cuda: Forwarded to ``Extractor``; set to ``False`` on
                machines without a CUDA device.
        """
        self.min_confidence = min_confidence
        self.nms_max_overlap = nms_max_overlap
        self.extractor = Extractor("mobilenetv2_x1_0",
                                   model_path,
                                   use_cuda=use_cuda)
        metric = NearestNeighborDistanceMetric("cosine", max_dist, nn_budget)
        self.tracker = Tracker(metric)

    def update(self, bbox_xywh, confidences, cls_ids, ori_img):
        """Advance the tracker by one frame and return the active tracks.

        Args:
            bbox_xywh: 2-D array of boxes, one ``[cx, cy, w, h]`` row per
                detection (center coordinates plus size).
            confidences: Per-detection confidence scores, aligned with
                ``bbox_xywh``.
            cls_ids: Per-detection class ids, aligned with ``bbox_xywh``.
            ori_img: The frame the detections come from; its first two
                dimensions are read as (height, width).

        Returns:
            An integer array with one ``[x1, y1, x2, y2, track_id, cls_id]``
            row per confirmed, recently updated track, or an empty 1-D array
            when there is none.
        """
        # Height and width of the original image; used to clip boxes.
        self.height, self.width = ori_img.shape[:2]
        # NOTE(review): the original comment gives the feature shape as
        # (num_dets, 128) - confirm against the extractor in use.
        features = self._get_features(bbox_xywh, ori_img)
        # Rows become [x1, y1, w, h] (top-left corner plus size).
        bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)

        detections = [
            Detection(bbox_tlwh[i], conf, cls_ids[i], features[i])
            for i, conf in enumerate(confidences)
            if conf > self.min_confidence
        ]

        # Update the tracker.
        self.tracker.predict()
        self.tracker.update(detections)

        # Collect box + identity for every track worth reporting.
        outputs = []
        for track in self.tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            x1, y1, x2, y2 = self._tlwh_to_xyxy(track.to_tlwh())
            # Builtin ``int`` replaces the ``np.int`` alias, which was
            # removed in NumPy 1.24; the resulting dtype is the same.
            outputs.append(
                np.array([x1, y1, x2, y2, track.track_id, track.cls_id],
                         dtype=int))

        if outputs:
            return np.stack(outputs, axis=0)
        return np.array(outputs)

    @staticmethod
    def _xywh_to_tlwh(bbox_xywh):
        """Convert ``[cx, cy, w, h]`` rows to ``[x1, y1, w, h]`` rows.

        The conversion is done on a copy so the caller's boxes are left
        untouched. An empty input is returned (as a copy) unchanged.
        """
        if hasattr(bbox_xywh, "clone"):
            # Tensor-like input exposing ``clone()`` (e.g. a torch tensor).
            bbox_tlwh = bbox_xywh.clone()
        else:
            bbox_tlwh = np.array(bbox_xywh)  # np.array copies by default
        if len(bbox_tlwh) == 0:
            # Nothing to convert; also avoids 2-D indexing on a 1-D empty
            # array such as np.array([]).
            return bbox_tlwh
        bbox_tlwh[:, 0] = bbox_tlwh[:, 0] - bbox_tlwh[:, 2] / 2.
        bbox_tlwh[:, 1] = bbox_tlwh[:, 1] - bbox_tlwh[:, 3] / 2.
        return bbox_tlwh

    def _xywh_to_xyxy(self, bbox_xywh):
        """Convert one ``(cx, cy, w, h)`` box to integer corner coordinates.

        The corners are clipped to the image bounds stored in ``self.width``
        and ``self.height`` (set by :meth:`update`).
        """
        x, y, w, h = bbox_xywh
        x1 = max(int(x - w / 2), 0)
        x2 = min(int(x + w / 2), self.width - 1)
        y1 = max(int(y - h / 2), 0)
        y2 = min(int(y + h / 2), self.height - 1)
        return x1, y1, x2, y2

    def _tlwh_to_xyxy(self, bbox_tlwh):
        """Convert one ``(x1, y1, w, h)`` box to integer corner coordinates.

        The corners are clipped to the image bounds stored in ``self.width``
        and ``self.height`` (set by :meth:`update`).
        """
        x, y, w, h = bbox_tlwh
        x1 = max(int(x), 0)
        x2 = min(int(x + w), self.width - 1)
        y1 = max(int(y), 0)
        y2 = min(int(y + h), self.height - 1)
        return x1, y1, x2, y2

    def _get_features(self, bbox_xywh, ori_img):
        """Crop every box out of ``ori_img`` and extract its features.

        Returns the extractor's output for the list of crops, or an empty
        array when there are no boxes.
        """
        im_crops = []
        for box in bbox_xywh:
            x1, y1, x2, y2 = self._xywh_to_xyxy(box)
            # NOTE(review): a box lying outside the frame yields a
            # zero-area crop here - confirm the extractor tolerates that.
            im_crops.append(ori_img[y1:y2, x1:x2])
        if im_crops:
            # Run the extractor here to obtain the embeddings.
            features = self.extractor(im_crops)
        else:
            features = np.array([])
        return features